﻿using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using VisitorDelegate = System.Func<System.Linq.Expressions.Expression, System.Linq.Expressions.Expression>;
using System.Diagnostics;

namespace QueryInterception
{
   

    /// <summary>
    /// An <see cref="IQueryProvider"/> that wraps another provider and runs a chain of
    /// expression visitors over each query expression before the wrapped provider executes it.
    /// </summary>
    /// <remarks>
    /// Based on http://blogs.msdn.com/b/alexj/archive/2010/03/01/tip-55-how-to-extend-an-iqueryable-by-wrapping-it.aspx
    /// Visitors are applied at execution time (<see cref="ExecuteQuery{TElement}"/>,
    /// <see cref="Execute{TResult}"/>, <see cref="Execute(Expression)"/>), not when a query is composed
    /// through <see cref="CreateQuery{TElement}"/>.
    /// </remarks>
    public class InterceptingProvider : IQueryProvider
    {
        readonly IQueryProvider _underlyingProvider;
        readonly VisitorDelegate[] _visitors;
        readonly VisitorDelegate _afterUnderlyingVisitor;

        /// <summary>Creates a provider; callers use the <c>Intercept</c> factory methods.</summary>
        /// <param name="afterUnderlyingVisitor">
        /// Optional (may be null). Invoked in <see cref="ExecuteQuery{TElement}"/> with the expression of the
        /// query produced by the underlying provider; its return value is ignored.
        /// </param>
        /// <param name="underlyingQueryProvider">The provider that actually executes queries.</param>
        /// <param name="visitors">Rewrites applied in array order; null is treated as "no visitors".</param>
        /// <exception cref="ArgumentNullException"><paramref name="underlyingQueryProvider"/> is null.</exception>
        private InterceptingProvider(VisitorDelegate afterUnderlyingVisitor,
            IQueryProvider underlyingQueryProvider,
            params VisitorDelegate[] visitors)
        {
            if (underlyingQueryProvider == null)
            {
                throw new ArgumentNullException("underlyingQueryProvider");
            }
            this._underlyingProvider = underlyingQueryProvider;
            this._afterUnderlyingVisitor = afterUnderlyingVisitor;
            // Normalize so InterceptExpr never has to null-check.
            this._visitors = visitors ?? new VisitorDelegate[0];
        }

        /// <summary>
        /// Wraps <paramref name="underlyingQuery"/> so that <paramref name="visitors"/> rewrite its
        /// expression before execution.
        /// </summary>
        /// <param name="afterUnderlyingVisitor">
        /// Optional (may be null). Shown the expression of the underlying provider's query; result ignored.
        /// </param>
        /// <param name="underlyingQuery">The query to wrap.</param>
        /// <param name="visitors">Visitors applied in array order; null means none.</param>
        /// <returns>A query whose provider is an <see cref="InterceptingProvider"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="underlyingQuery"/> is null.</exception>
        public static IQueryable<T> Intercept<T>(
            ExpressionVisitor afterUnderlyingVisitor,
            IQueryable<T> underlyingQuery,
            params ExpressionVisitor[] visitors)
        {
            VisitorDelegate afterDelegate = afterUnderlyingVisitor != null
                ? (VisitorDelegate)afterUnderlyingVisitor.Visit
                : null;
            return Intercept<T>(afterDelegate, underlyingQuery, ToDelegates(visitors));
        }

        /// <summary>
        /// Wraps <paramref name="underlyingQuery"/> so that <paramref name="visitors"/> rewrite its
        /// expression before execution (no after-underlying callback).
        /// </summary>
        /// <param name="underlyingQuery">The query to wrap.</param>
        /// <param name="visitors">Visitors applied in array order; null means none.</param>
        /// <returns>A query whose provider is an <see cref="InterceptingProvider"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="underlyingQuery"/> is null.</exception>
        public static IQueryable<T> Intercept<T>(
            IQueryable<T> underlyingQuery,
            params ExpressionVisitor[] visitors)
        {
            return Intercept<T>((VisitorDelegate)null, underlyingQuery, ToDelegates(visitors));
        }

        /// <summary>
        /// Wraps <paramref name="underlyingQuery"/> so that the <paramref name="visitors"/> functions
        /// rewrite its expression, in order, before execution.
        /// </summary>
        /// <param name="afterUnderlyingVisitor">
        /// Optional (may be null). Shown the expression of the underlying provider's query; result ignored.
        /// </param>
        /// <param name="underlyingQuery">The query to wrap.</param>
        /// <param name="visitors">Rewrite functions applied in array order; null means none.</param>
        /// <returns>A query whose provider is an <see cref="InterceptingProvider"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="underlyingQuery"/> is null.</exception>
        public static IQueryable<T> Intercept<T>(Func<Expression, Expression> afterUnderlyingVisitor,
            IQueryable<T> underlyingQuery,
            params Func<Expression, Expression>[] visitors)
        {
            if (underlyingQuery == null)
            {
                throw new ArgumentNullException("underlyingQuery");
            }
            var provider = new InterceptingProvider(afterUnderlyingVisitor,
                underlyingQuery.Provider,
                visitors
            );
            return provider.CreateQuery<T>(underlyingQuery.Expression);
        }

        /// <summary>Converts visitors to their <c>Visit</c> delegates; null yields an empty array.</summary>
        private static VisitorDelegate[] ToDelegates(ExpressionVisitor[] visitors)
        {
            if (visitors == null)
            {
                return new VisitorDelegate[0];
            }
            return visitors
                .Select(v => (VisitorDelegate)v.Visit)
                .ToArray();
        }

        /// <summary>
        /// When true, <see cref="ExecuteQuery{TElement}"/> writes the underlying query to
        /// <see cref="Trace"/>. Entity Framework queries are written as SQL (ToTraceString); any other
        /// query is written as its expression tree.
        /// </summary>
        public static bool DoTrace { get; set; }

        /// <summary>
        /// Rewrites <paramref name="expression"/> with the configured visitors, has the underlying
        /// provider build and run the query, and returns the results as <typeparamref name="TElement"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): presumably called by <c>InterceptedQuery&lt;T&gt;</c> when it is enumerated — confirm.
        /// If the underlying enumerator already yields <typeparamref name="TElement"/>, it is returned as-is
        /// and the caller owns it. Otherwise results are fully materialized and converted, either by cast
        /// (when the underlying element type is assignable to <typeparamref name="TElement"/>) or by copying
        /// same-named properties into a new <typeparamref name="TElement"/> (intended for anonymous types).
        /// </remarks>
        /// <exception cref="NotSupportedException">Rows cannot be translated to <typeparamref name="TElement"/>.</exception>
        public IEnumerator<TElement> ExecuteQuery<TElement>(
            Expression expression)
        {
            Expression intercepted;
            using (Profiler.Step("intercepting query"))
            {
                intercepted = InterceptExpr(expression);
            }

            IQueryable underlyingQuery;
            using (Profiler.Step("Ef Translating query"))
            {
                underlyingQuery = _underlyingProvider.CreateQuery(intercepted);
            }

            // NOTE(review): project-specific check that the rewritten expression's type name does not
            // contain "Shared"; the reason is not visible in this file — confirm with the visitor authors.
            Debug.Assert(intercepted.Type.FullName.Contains("Shared") == false);

            if (DoTrace)
            {
                using (Profiler.Step("ToTraceString"))
                {
                    // Only EF ObjectQuery can render SQL. Other providers fall back to the expression
                    // tree instead of failing with InvalidCastException.
                    var objectQuery = underlyingQuery as System.Data.Objects.ObjectQuery;
                    Trace.WriteLine(objectQuery != null
                        ? objectQuery.ToTraceString()
                        : underlyingQuery.Expression.ToString());
                }
            }

            if (_afterUnderlyingVisitor != null)
            {
                // Invoked for its side effects only; the returned expression is not used.
                _afterUnderlyingVisitor(underlyingQuery.Expression);
            }

            using (Profiler.Step("enumerating query"))
            {
                var enumerator = underlyingQuery.GetEnumerator();

                // Fast path: no conversion needed; the caller owns (and disposes) the enumerator.
                // Checked before any element-type inspection so it cannot be blocked by it.
                var typedEnumerator = enumerator as IEnumerator<TElement>;
                if (typedEnumerator != null)
                {
                    return typedEnumerator;
                }

                try
                {
                    // IQueryable.ElementType is the documented element type. Reading generic arguments
                    // off the enumerator's runtime type fails for non-generic or multi-argument enumerators.
                    return Materialize<TElement>(enumerator, underlyingQuery.ElementType);
                }
                finally
                {
                    // Dispose even if draining or translation throws.
                    var disposable = enumerator as IDisposable;
                    if (disposable != null)
                    {
                        disposable.Dispose();
                    }
                }
            }
        }

        /// <summary>
        /// Drains <paramref name="enumerator"/> into a list of <typeparamref name="TElement"/>, converting
        /// each item from <paramref name="sourceType"/>. Does not dispose the enumerator.
        /// </summary>
        private static IEnumerator<TElement> Materialize<TElement>(
            System.Collections.IEnumerator enumerator, Type sourceType)
        {
            var items = new List<TElement>();
            if (typeof(TElement).IsAssignableFrom(sourceType))
            {
                while (enumerator.MoveNext())
                {
                    items.Add((TElement)enumerator.Current);
                }
            }
            else
            {
                // Needs to translate one anonymous type to another. The translator is built on the first
                // row, so an empty result never requires a translatable target type.
                Func<object, object> translate = null;
                while (enumerator.MoveNext())
                {
                    if (translate == null)
                    {
                        translate = CreateAnonymousTranslator(sourceType, typeof(TElement));
                    }
                    items.Add((TElement)translate(enumerator.Current));
                }
            }
            return items.GetEnumerator();
        }

        /// <summary>
        /// Builds a function that creates a <paramref name="targetType"/> instance from a
        /// <paramref name="sourceType"/> instance by reading the source's same-named public properties.
        /// </summary>
        /// <remarks>
        /// The target must have a constructor whose parameter types are its generic type arguments, which is
        /// the shape of a C# anonymous type. Constructor arguments are supplied in the order that reflection
        /// returns the target's properties. Property values are passed through unchanged, so nested anonymous
        /// types are not translated.
        /// </remarks>
        /// <exception cref="NotSupportedException">
        /// No suitable constructor exists, or the target's properties do not all have a matching readable
        /// source property.
        /// </exception>
        private static Func<object, object> CreateAnonymousTranslator(Type sourceType, Type targetType)
        {
            var targetConstructor = targetType.GetConstructor(targetType.GetGenericArguments());
            if (targetConstructor == null)
            {
                throw new NotSupportedException(string.Format(
                    "Cannot translate {0} to {1}: {1} has no constructor taking its generic arguments.",
                    sourceType, targetType));
            }

#warning does not handle nested anonymous types
            // Built once per translator rather than once per row.
            var getters = (from tp in targetType.GetProperties()
                           join ep in sourceType.GetProperties()
                           on tp.Name equals ep.Name
                           select ep.GetGetMethod()).ToArray();

            if (getters.Length != targetConstructor.GetParameters().Length || getters.Any(g => g == null))
            {
                throw new NotSupportedException(string.Format(
                    "Cannot translate {0} to {1}: the types do not expose matching readable properties.",
                    sourceType, targetType));
            }

            return source =>
            {
                var args = new object[getters.Length];
                for (int i = 0; i < getters.Length; i++)
                {
                    args[i] = getters[i].Invoke(source, null);
                }
                return targetConstructor.Invoke(args);
            };
        }

        /// <summary>
        /// Returns <paramref name="source"/> cast to <typeparamref name="TResult"/> when
        /// <paramref name="inputType"/> is assignable to it. Otherwise returns a new
        /// <typeparamref name="TResult"/> built from same-named properties (anonymous-type translation).
        /// </summary>
        /// <exception cref="NotSupportedException">The types cannot be translated.</exception>
        private TResult TranslateAnonymous<TResult>(Type inputType, object source)
        {
            var targetType = typeof(TResult);
            if (targetType.IsAssignableFrom(inputType))
            {
                return (TResult)source;
            }
            return (TResult)CreateAnonymousTranslator(inputType, targetType)(source);
        }

        /// <summary>Creates a query rooted at this provider; visitors run only when it executes.</summary>
        public IQueryable<TElement> CreateQuery<TElement>(
            Expression expression)
        {
            return new InterceptedQuery<TElement>(this, expression);
        }

        /// <summary>
        /// Non-generic counterpart of <see cref="CreateQuery{TElement}"/>. It constructs
        /// <c>InterceptedQuery&lt;T&gt;</c> by reflection for the element type found by <c>TypeHelper</c>.
        /// </summary>
        /// <exception cref="InvalidOperationException">The expected constructor was not found.</exception>
        public IQueryable CreateQuery(Expression expression)
        {
            Type elementType = TypeHelper.FindIEnumerable(expression.Type);
            Type queryType = typeof(InterceptedQuery<>).MakeGenericType(elementType);

            var ctor = queryType.GetConstructor(
                BindingFlags.NonPublic | BindingFlags.Instance,
                null,
                new Type[] { typeof(InterceptingProvider), typeof(Expression) },
                null);
            if (ctor == null)
            {
                throw new InvalidOperationException(string.Format(
                    "{0} has no non-public (InterceptingProvider, Expression) constructor.", queryType));
            }

            return (IQueryable)ctor.Invoke(new object[] { this, expression });
        }

        /// <summary>
        /// Rewrites and executes a single-result query (e.g. an aggregate). The result is translated to
        /// <typeparamref name="TResult"/> if needed; a null result yields <c>default(TResult)</c>.
        /// </summary>
        public TResult Execute<TResult>(Expression expression)
        {
            var intercepted = InterceptExpr(expression);
            var result = this._underlyingProvider.Execute(intercepted);
            if (result == null)
            {
                return default(TResult);
            }
            return TranslateAnonymous<TResult>(result.GetType(), result);
        }

        /// <summary>Rewrites <paramref name="expression"/> and executes it on the underlying provider.</summary>
        public object Execute(Expression expression)
        {
            return this._underlyingProvider.Execute(
                InterceptExpr(expression)
            );
        }

        /// <summary>Applies every configured visitor to <paramref name="expression"/>, in order.</summary>
        private Expression InterceptExpr(Expression expression)
        {
            Expression exp = expression;
            foreach (var visitor in _visitors)
            {
                exp = visitor(exp);
            }
            return exp;
        }
    }

}
